/*
 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *		Axel Dörfler, axeld@pinc-software.de
 */


#include "ipv4.h"
#include "ipv4_address.h"
#include "multicast.h"

#include <net_datalink.h>
#include <net_datalink_protocol.h>
#include <net_device.h>
#include <net_protocol.h>
#include <net_stack.h>
#include <NetBufferUtilities.h>
#include <ProtocolUtilities.h>

#include <KernelExport.h>
#include <util/AutoLock.h>
#include <util/list.h>
#include <util/DoublyLinkedList.h>
#include <util/MultiHashTable.h>

#include <netinet/in.h>
#include <netinet/ip.h>
#include <new>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <utility>


//#define TRACE_IPV4
#ifdef TRACE_IPV4
#	define TRACE(format, args...) \
		dprintf("IPv4 [%" B_PRIdBIGTIME "] " format "\n", system_time() , \
			##args)
#	define TRACE_SK(protocol, format, args...) \
		dprintf("IPv4 [%" B_PRIdBIGTIME "] %p " format "\n", system_time(), \
			protocol , ##args)
#	define TRACE_ONLY(x) x
#else
#	define TRACE(args...) ;
#	define TRACE_SK(args...) ;
#	define TRACE_ONLY(x)
#endif


#define MAX_HASH_FRAGMENTS 		64
	// slots in the fragment packet's hash
#define FRAGMENT_TIMEOUT		60000000LL
	// discard fragment after 60 seconds


typedef DoublyLinkedList<struct net_buffer,
	DoublyLinkedListCLink<struct net_buffer> > FragmentList;

typedef NetBufferField<uint16, offsetof(ipv4_header, checksum)> IPChecksumField;

struct ipv4_packet_key {
	in_addr_t	source;
	in_addr_t	destination;
	uint16		id;
	uint8		protocol;
};


class FragmentPacket {
public:
								FragmentPacket(const ipv4_packet_key& key);
								~FragmentPacket();

			status_t			AddFragment(uint16 start, uint16 end,
									net_buffer* buffer, bool lastFragment);
			status_t			Reassemble(net_buffer* to);

			bool				IsComplete() const
									{ return fReceivedLastFragment
										&& fBytesLeft == 0; }

			const ipv4_packet_key& Key() const { return fKey; }
			FragmentPacket*&	HashTableLink() { return fNext; }

	static	void				StaleTimer(struct net_timer* timer, void* data);

private:
			FragmentPacket*		fNext;
			struct ipv4_packet_key fKey;
			uint32				fIndex;
			bool				fReceivedLastFragment;
			int32				fBytesLeft;
			FragmentList		fFragments;
			net_timer			fTimer;
};


struct FragmentHashDefinition {
	typedef ipv4_packet_key KeyType;
	typedef FragmentPacket ValueType;

	size_t HashKey(const KeyType& key) const
	{
		return (key.source ^ key.destination ^ key.protocol ^ key.id);
	}

	size_t Hash(ValueType* value) const
	{
		return HashKey(value->Key());
	}

	bool Compare(const KeyType& key, ValueType* value) const
	{
		const ipv4_packet_key& packetKey = value->Key();

		return packetKey.id == key.id
			&& packetKey.source == key.source
			&& packetKey.destination == key.destination
			&& packetKey.protocol == key.protocol;
	}

	ValueType*& GetLink(ValueType* value) const
	{
		return value->HashTableLink();
	}
};

typedef BOpenHashTable<FragmentHashDefinition, false, true> FragmentTable;


class RawSocket
	: public DoublyLinkedListLinkImpl<RawSocket>, public DatagramSocket<> {
public:
								RawSocket(net_socket* socket);
};

typedef DoublyLinkedList<RawSocket> RawSocketList;

typedef MulticastGroupInterface<IPv4Multicast> IPv4GroupInterface;
typedef MulticastFilter<IPv4Multicast> IPv4MulticastFilter;

struct MulticastStateHash {
	typedef std::pair<const in_addr* , uint32> KeyType;
	typedef IPv4GroupInterface ValueType;

	size_t HashKey(const KeyType &key) const
		{ return key.first->s_addr ^ key.second; }
	size_t Hash(ValueType* value) const
		{ return HashKey(std::make_pair(&value->Address(),
			value->Interface()->index)); }
	bool Compare(const KeyType &key, ValueType* value) const
		{ return value->Interface()->index == key.second
			&& value->Address().s_addr == key.first->s_addr; }
	bool CompareValues(ValueType* value1, ValueType* value2) const
		{ return value1->Interface()->index == value2->Interface()->index
			&& value1->Address().s_addr == value2->Address().s_addr; }
	ValueType*& GetLink(ValueType* value) const { return value->MulticastGroupsHashLink(); }
};


struct ipv4_protocol : net_protocol {
	ipv4_protocol()
		:
		raw(NULL),
		multicast_filter(this)
	{
	}

	~ipv4_protocol()
	{
		delete raw;
	}

	RawSocket*			raw;
	uint8				service_type;
	uint8				time_to_live;
	uint8				multicast_time_to_live;
	bool				multicast_loopback;
	uint32				flags;
	struct sockaddr*	multicast_address; // for IP_MULTICAST_IF

	IPv4MulticastFilter	multicast_filter;
};

// protocol flags
#define IP_FLAG_HEADER_INCLUDED		0x01
#define IP_FLAG_RECEIVE_DEST_ADDR	0x02


static const int kDefaultTTL = 254;
static const int kDefaultMulticastTTL = 1;
static const bool kDefaultMulticastLoopback = true;


extern net_protocol_module_info gIPv4Module;
	// we need this in ipv4_std_ops() for registering the AF_INET domain

net_stack_module_info* gStackModule;
net_buffer_module_info* gBufferModule;

static struct net_domain* sDomain;
static net_datalink_module_info* sDatalinkModule;
static net_socket_module_info* sSocketModule;
static int32 sPacketID;
static RawSocketList sRawSockets;
static mutex sRawSocketsLock;
static mutex sFragmentLock;
static FragmentTable sFragmentHash;
static mutex sMulticastGroupsLock;

typedef MultiHashTable<MulticastStateHash> MulticastState;
static MulticastState* sMulticastState;

static net_protocol_module_info* sReceivingProtocol[256];
static mutex sReceivingProtocolLock;


static const char*
print_address(const in_addr* address, char* buf, size_t bufLen)
{
	unsigned int addr = ntohl(address->s_addr);

	snprintf(buf, bufLen, "%u.%u.%u.%u", (addr >> 24) & 0xff,
		(addr >> 16) & 0xff, (addr >> 8) & 0xff, addr & 0xff);

	return buf;
}


RawSocket::RawSocket(net_socket* socket)
	:
	DatagramSocket<>("ipv4 raw socket", socket)
{
}


//	#pragma mark -


FragmentPacket::FragmentPacket(const ipv4_packet_key& key)
	:
	fKey(key),
	fIndex(0),
	fReceivedLastFragment(false),
	fBytesLeft(IP_MAXPACKET)
{
	gStackModule->init_timer(&fTimer, FragmentPacket::StaleTimer, this);
}


FragmentPacket::~FragmentPacket()
{
	// cancel the kill timer
	gStackModule->set_timer(&fTimer, -1);

	// delete all fragments
	net_buffer* buffer;
	while ((buffer = fFragments.RemoveHead()) != NULL) {
		gBufferModule->free(buffer);
	}
}


status_t
FragmentPacket::AddFragment(uint16 start, uint16 end, net_buffer* buffer,
	bool lastFragment)
{
	// restart the timer
	gStackModule->set_timer(&fTimer, FRAGMENT_TIMEOUT);

	if (start >= end) {
		// invalid fragment
		return B_BAD_DATA;
	}

	// Search for a position in the list to insert the fragment

	FragmentList::ReverseIterator iterator = fFragments.GetReverseIterator();
	net_buffer* previous = NULL;
	net_buffer* next = NULL;
	while ((previous = iterator.Next()) != NULL) {
		if (previous->fragment.start <= start) {
			// The new fragment can be inserted after this one
			break;
		}

		next = previous;
	}

	// See if we already have the fragment's data

	if (previous != NULL && previous->fragment.start <= start
		&& previous->fragment.end >= end) {
		// we do, so we can just drop this fragment
		gBufferModule->free(buffer);
		return B_OK;
	}

	fIndex = buffer->index;
		// adopt the buffer's device index

	TRACE("    previous: %p, next: %p", previous, next);

	// If we have parts of the data already, truncate as needed

	if (previous != NULL && previous->fragment.end > start) {
		TRACE("    remove header %d bytes", previous->fragment.end - start);
		gBufferModule->remove_header(buffer, previous->fragment.end - start);
		start = previous->fragment.end;
	}
	if (next != NULL && end > next->fragment.start) {
		TRACE("    remove trailer %d bytes", end - next->fragment.start);
		gBufferModule->remove_trailer(buffer, end - next->fragment.start);
		end = next->fragment.start;
	}

	// Now try if we can already merge the fragments together

	// We will always keep the last buffer received, so that we can still
	// report an error (in which case we're not responsible for freeing it)

	if (previous != NULL && previous->fragment.end == start) {
		fFragments.Remove(previous);

		buffer->fragment.start = previous->fragment.start;
		buffer->fragment.end = end;

		status_t status = gBufferModule->merge(buffer, previous, false);
		TRACE("    merge previous: %s", strerror(status));
		if (status != B_OK) {
			fFragments.InsertBefore(next, previous);
			return status;
		}

		fFragments.InsertBefore(next, buffer);

		// cut down existing hole
		fBytesLeft -= end - start;

		if (lastFragment && !fReceivedLastFragment) {
			fReceivedLastFragment = true;
			fBytesLeft -= IP_MAXPACKET - end;
		}

		TRACE("    hole length: %d", (int)fBytesLeft);

		return B_OK;
	} else if (next != NULL && next->fragment.start == end) {
		net_buffer* afterNext = (net_buffer*)next->link.next;
		fFragments.Remove(next);

		buffer->fragment.start = start;
		buffer->fragment.end = next->fragment.end;

		status_t status = gBufferModule->merge(buffer, next, true);
		TRACE("    merge next: %s", strerror(status));
		if (status != B_OK) {
			// Insert "next" at its previous position
			fFragments.InsertBefore(afterNext, next);
			return status;
		}

		fFragments.InsertBefore(afterNext, buffer);

		// cut down existing hole
		fBytesLeft -= end - start;

		if (lastFragment && !fReceivedLastFragment) {
			fReceivedLastFragment = true;
			fBytesLeft -= IP_MAXPACKET - end;
		}

		TRACE("    hole length: %d", (int)fBytesLeft);

		return B_OK;
	}

	// We couldn't merge the fragments, so we need to add it as is

	TRACE("    new fragment: %p, bytes %d-%d", buffer, start, end);

	buffer->fragment.start = start;
	buffer->fragment.end = end;
	fFragments.InsertBefore(next, buffer);

	// update length of the hole, if any
	fBytesLeft -= end - start;

	if (lastFragment && !fReceivedLastFragment) {
		fReceivedLastFragment = true;
		fBytesLeft -= IP_MAXPACKET - end;
	}

	TRACE("    hole length: %d", (int)fBytesLeft);

	return B_OK;
}


/*!	Reassembles the fragments to the specified buffer \a to.
	This buffer must have been added via AddFragment() before.
*/
status_t
FragmentPacket::Reassemble(net_buffer* to)
{
	if (!IsComplete())
		return B_ERROR;

	net_buffer* buffer = NULL;

	net_buffer* fragment;
	while ((fragment = fFragments.RemoveHead()) != NULL) {
		if (buffer != NULL) {
			status_t status;
			if (to == fragment) {
				status = gBufferModule->merge(fragment, buffer, false);
				buffer = fragment;
			} else
				status = gBufferModule->merge(buffer, fragment, true);
			if (status != B_OK)
				return status;
		} else
			buffer = fragment;
	}

	if (buffer != to)
		panic("ipv4 packet reassembly did not work correctly.");

	to->index = fIndex;
		// reset the buffer's device index

	return B_OK;
}


/*static*/ void
FragmentPacket::StaleTimer(struct net_timer* timer, void* data)
{
	FragmentPacket* packet = (FragmentPacket*)data;
	TRACE("Assembling FragmentPacket %p timed out!", packet);

	MutexLocker locker(&sFragmentLock);
	sFragmentHash.Remove(packet);
	locker.Unlock();

	if (!packet->fFragments.IsEmpty()) {
		// Send error: fragment reassembly time exceeded
		sDomain->module->error_reply(NULL, packet->fFragments.First(),
			B_NET_ERROR_REASSEMBLY_TIME_EXCEEDED, NULL);
	}

	delete packet;
}


//	#pragma mark -


#ifdef TRACE_IPV4
static void
dump_ipv4_header(ipv4_header &header)
{
	struct pretty_ipv4 {
	#if B_HOST_IS_LENDIAN == 1
		uint8 a;
		uint8 b;
		uint8 c;
		uint8 d;
	#else
		uint8 d;
		uint8 c;
		uint8 b;
		uint8 a;
	#endif
	};
	struct pretty_ipv4* src = (struct pretty_ipv4*)&header.source;
	struct pretty_ipv4* dst = (struct pretty_ipv4*)&header.destination;
	dprintf("  version: %d\n", header.version);
	dprintf("  header_length: 4 * %d\n", header.header_length);
	dprintf("  service_type: %d\n", header.service_type);
	dprintf("  total_length: %d\n", header.TotalLength());
	dprintf("  id: %d\n", ntohs(header.id));
	dprintf("  fragment_offset: %d (flags: %c%c%c)\n",
		header.FragmentOffset() & IP_FRAGMENT_OFFSET_MASK,
		(header.FragmentOffset() & IP_RESERVED_FLAG) ? 'r' : '-',
		(header.FragmentOffset() & IP_DONT_FRAGMENT) ? 'd' : '-',
		(header.FragmentOffset() & IP_MORE_FRAGMENTS) ? 'm' : '-');
	dprintf("  time_to_live: %d\n", header.time_to_live);
	dprintf("  protocol: %d\n", header.protocol);
	dprintf("  checksum: %d\n", ntohs(header.checksum));
	dprintf("  source: %d.%d.%d.%d\n", src->a, src->b, src->c, src->d);
	dprintf("  destination: %d.%d.%d.%d\n", dst->a, dst->b, dst->c, dst->d);
}
#endif	// TRACE_IPV4


static int
dump_ipv4_multicast(int argc, char** argv)
{
	MulticastState::Iterator groupIterator = sMulticastState->GetIterator();

	while (groupIterator.HasNext()) {
		IPv4GroupInterface* state = groupIterator.Next();

		char addressBuffer[64];

		kprintf("%p: group <%s, %s, %s {", state, state->Interface()->name,
			print_address(&state->Address(), addressBuffer,
			sizeof(addressBuffer)),
			state->Mode() == IPv4GroupInterface::kExclude
				? "Exclude" : "Include");

		int count = 0;
		IPv4GroupInterface::AddressSet::Iterator addressIterator
			= state->Sources().GetIterator();
		while (addressIterator.HasNext()) {
			kprintf("%s%s", count > 0 ? ", " : "",
				print_address(&addressIterator.Next(),
				addressBuffer, sizeof(addressBuffer)));
			count++;
		}

		kprintf("}> sock %p\n", state->Parent()->Socket());
	}

	return 0;
}


/*!	Attempts to re-assemble fragmented packets.
	\return B_OK if everything went well; if it could reassemble the packet, \a _buffer
		will point to its buffer, otherwise, it will be \c NULL.
	\return various error codes if something went wrong (mostly B_NO_MEMORY)
*/
static status_t
reassemble_fragments(const ipv4_header &header, net_buffer** _buffer)
{
	net_buffer* buffer = *_buffer;
	status_t status;

	struct ipv4_packet_key key;
	key.source = (in_addr_t)header.source;
	key.destination = (in_addr_t)header.destination;
	key.id = header.id;
	key.protocol = header.protocol;

	// TODO: Make locking finer grained.
	MutexLocker locker(&sFragmentLock);

	FragmentPacket* packet = sFragmentHash.Lookup(key);
	if (packet == NULL) {
		// New fragment packet
		packet = new (std::nothrow) FragmentPacket(key);
		if (packet == NULL)
			return B_NO_MEMORY;

		// add packet to hash
		status = sFragmentHash.Insert(packet);
		if (status != B_OK) {
			delete packet;
			return status;
		}
	}

	uint16 fragmentOffset = header.FragmentOffset();
	uint16 start = (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) << 3;
	uint16 end = start + header.TotalLength() - header.HeaderLength();
	bool lastFragment = (fragmentOffset & IP_MORE_FRAGMENTS) == 0;

	TRACE("   Received IPv4 %sfragment of size %d, offset %d.",
		lastFragment ? "last ": "", end - start, start);

	// Remove header unless this is the first fragment
	if (start != 0)
		gBufferModule->remove_header(buffer, header.HeaderLength());

	status = packet->AddFragment(start, end, buffer, lastFragment);
	if (status != B_OK)
		return status;

	if (packet->IsComplete()) {
		sFragmentHash.Remove(packet);
			// no matter if reassembling succeeds, we won't need this packet
			// anymore

		status = packet->Reassemble(buffer);
		delete packet;

		// _buffer does not change
		return status;
	}

	// This indicates that the packet is not yet complete
	*_buffer = NULL;
	return B_OK;
}


/*!	Fragments the incoming buffer and send all fragments via the specified
	\a route.
*/
static status_t
send_fragments(ipv4_protocol* protocol, struct net_route* route,
	net_buffer* buffer, uint32 mtu)
{
	TRACE_SK(protocol, "SendFragments(%" B_PRIu32 " bytes, mtu %" B_PRIu32 ")",
		buffer->size, mtu);

	NetBufferHeaderReader<ipv4_header> originalHeader(buffer);
	if (originalHeader.Status() != B_OK)
		return originalHeader.Status();

	uint16 headerLength = originalHeader->HeaderLength();
	uint32 bytesLeft = buffer->size - headerLength;
	uint32 fragmentOffset = 0;
	status_t status = B_OK;

	net_buffer* headerBuffer = gBufferModule->split(buffer, headerLength);
	if (headerBuffer == NULL)
		return B_NO_MEMORY;

	// TODO: we need to make sure ipv4_header is contiguous or
	// use another construct.
	NetBufferHeaderReader<ipv4_header> bufferHeader(headerBuffer);
	ipv4_header* header = &bufferHeader.Data();

	// Adapt MTU to be a multiple of 8 (fragment offsets can only be specified
	// this way)
	mtu -= headerLength;
	mtu &= ~7;
	TRACE("  adjusted MTU to %" B_PRIu32 ", bytesLeft %" B_PRIu32, mtu,
		bytesLeft);

	while (bytesLeft > 0) {
		uint32 fragmentLength = min_c(bytesLeft, mtu);
		bytesLeft -= fragmentLength;
		bool lastFragment = bytesLeft == 0;

		header->total_length = htons(fragmentLength + headerLength);
		header->fragment_offset = htons((lastFragment ? 0 : IP_MORE_FRAGMENTS)
			| (fragmentOffset >> 3));
		header->checksum = 0;
		header->checksum = gStackModule->checksum((uint8*)header,
			headerLength);
			// TODO: compute the checksum only for those parts that changed?

		TRACE("  send fragment of %" B_PRIu32 " bytes (%" B_PRIu32 " bytes "
			"left)", fragmentLength, bytesLeft);

		net_buffer* fragmentBuffer;
		if (!lastFragment) {
			fragmentBuffer = gBufferModule->split(buffer, fragmentLength);
			fragmentOffset += fragmentLength;
		} else
			fragmentBuffer = buffer;

		if (fragmentBuffer == NULL) {
			status = B_NO_MEMORY;
			break;
		}

		// copy header to fragment
		status = gBufferModule->prepend(fragmentBuffer, header, headerLength);

		// send fragment
		if (status == B_OK)
			status = sDatalinkModule->send_routed_data(route, fragmentBuffer);

		if (lastFragment) {
			// we don't own the last buffer, so we don't have to free it
			break;
		}

		if (status != B_OK) {
			gBufferModule->free(fragmentBuffer);
			break;
		}
	}

	gBufferModule->free(headerBuffer);
	return status;
}


status_t ipv4_receive_data(net_buffer* buffer);


/*!	Delivers the provided \a buffer to all listeners of this multicast group.
	Does not take over ownership of the buffer.
*/
static bool
deliver_multicast(net_protocol_module_info* module, net_buffer* buffer,
	bool deliverToRaw)
{
	TRACE("deliver_multicast(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);
	if (module->deliver_data == NULL)
		return false;

	MutexLocker _(sMulticastGroupsLock);

	sockaddr_in* multicastAddr = (sockaddr_in*)buffer->destination;

	uint32 index = buffer->index;
	if (buffer->interface_address != NULL)
		index = buffer->interface_address->interface->index;

	MulticastState::ValueIterator it = sMulticastState->Lookup(std::make_pair(
		&multicastAddr->sin_addr, index));

	size_t count = 0;

	while (it.HasNext()) {
		IPv4GroupInterface* state = it.Next();

		ipv4_protocol* ipProtocol = state->Parent()->Socket();
		if (deliverToRaw && (ipProtocol->raw == NULL
				|| ipProtocol->socket->protocol != buffer->protocol))
			continue;

		if (state->FilterAccepts(buffer)) {
			net_protocol* protocol = ipProtocol;
			if (protocol->module != module) {
				// as multicast filters are installed with an IPv4 protocol
				// reference, we need to go and find the appropriate instance
				// related to the 'receiving protocol' with module 'module'.
				protocol = ipProtocol->socket->first_protocol;

				while (protocol != NULL && protocol->module != module)
					protocol = protocol->next;
			}

			if (protocol != NULL) {
				module->deliver_data(protocol, buffer);
				count++;
			}
		}
	}

	return count > 0;
}


/*!	Delivers the buffer to all listening raw sockets without taking ownership of
	the provided \a buffer.
	Returns \c true if there was any receiver, \c false if not.
*/
static bool
raw_receive_data(net_buffer* buffer)
{
	MutexLocker locker(sRawSocketsLock);

	if (sRawSockets.IsEmpty())
		return false;

	TRACE("RawReceiveData(%i)", buffer->protocol);

	if ((buffer->msg_flags & MSG_MCAST) != 0) {
		// we need to call deliver_multicast here separately as
		// buffer still has the IP header, and it won't in the
		// next call. This isn't very optimized but works for now.
		// A better solution would be to hold separate hash tables
		// and lists for RAW and non-RAW sockets.
		return deliver_multicast(&gIPv4Module, buffer, true);
	}

	RawSocketList::Iterator iterator = sRawSockets.GetIterator();
	size_t count = 0;

	while (iterator.HasNext()) {
		RawSocket* raw = iterator.Next();

		if (raw->Socket()->protocol == buffer->protocol) {
			raw->EnqueueClone(buffer);
			count++;
		}
	}

	return count > 0;
}


static inline sockaddr*
fill_sockaddr_in(sockaddr_in* target, in_addr_t address)
{
	target->sin_family = AF_INET;
	target->sin_len = sizeof(sockaddr_in);
	target->sin_port = 0;
	target->sin_addr.s_addr = address;
	return (sockaddr*)target;
}


static status_t
get_int_option(void* target, size_t length, int value)
{
	if (length != sizeof(int))
		return B_BAD_VALUE;

	return user_memcpy(target, &value, sizeof(int));
}


static status_t
get_char_int_option(void* target, size_t length, int value)
{
	if (length == sizeof(int))
		return user_memcpy(target, &value, sizeof(int));
	if (length == sizeof(unsigned char)) {
		unsigned char uvalue = value;
		return user_memcpy(target, &uvalue, sizeof(uvalue));
	}
	return B_BAD_VALUE;
}


template<typename Type> static status_t
set_int_option(Type &target, const void* _value, size_t length)
{
	int value;

	if (length != sizeof(int))
		return B_BAD_VALUE;

	if (user_memcpy(&value, _value, sizeof(int)) != B_OK)
		return B_BAD_ADDRESS;

	target = value;
	return B_OK;
}


template<typename Type> static status_t
set_char_int_option(Type &target, const void* _value, size_t length)
{
	if (length == sizeof(int)) {
		int value;
		if (user_memcpy(&value, _value, sizeof(int)) != B_OK)
			return B_BAD_ADDRESS;
		if (value > 255)
			return B_BAD_VALUE;
		target = value;
		return B_OK;
	}
	if (length == sizeof(unsigned char)) {
		unsigned char value;
		if (user_memcpy(&value, _value, sizeof(value)) != B_OK)
			return B_BAD_ADDRESS;

		target = value;
		return B_OK;
	}
	return B_BAD_VALUE;
}


static net_protocol_module_info*
receiving_protocol(uint8 protocol)
{
	net_protocol_module_info* module = sReceivingProtocol[protocol];
	if (module != NULL)
		return module;

	MutexLocker locker(sReceivingProtocolLock);

	module = sReceivingProtocol[protocol];
	if (module != NULL)
		return module;

	if (gStackModule->get_domain_receiving_protocol(sDomain, protocol,
			&module) == B_OK)
		sReceivingProtocol[protocol] = module;

	return module;
}


// #pragma mark - multicast


status_t
IPv4Multicast::JoinGroup(IPv4GroupInterface* state)
{
	MutexLocker _(sMulticastGroupsLock);

	sockaddr_in groupAddr;
	status_t status = sDatalinkModule->join_multicast(state->Interface(),
		sDomain, fill_sockaddr_in(&groupAddr, state->Address().s_addr));
	if (status != B_OK)
		return status;

	sMulticastState->Insert(state);
	return B_OK;
}


status_t
IPv4Multicast::LeaveGroup(IPv4GroupInterface* state)
{
	MutexLocker _(sMulticastGroupsLock);

	sMulticastState->Remove(state);

	sockaddr_in groupAddr;
	return sDatalinkModule->leave_multicast(state->Interface(), sDomain,
		fill_sockaddr_in(&groupAddr, state->Address().s_addr));
}


static status_t
ipv4_delta_group(IPv4GroupInterface* group, int option,
	net_interface* interface, const in_addr* sourceAddr)
{
	switch (option) {
		case IP_ADD_MEMBERSHIP:
			return group->Add();
		case IP_DROP_MEMBERSHIP:
			return group->Drop();
		case IP_BLOCK_SOURCE:
			return group->BlockSource(*sourceAddr);
		case IP_UNBLOCK_SOURCE:
			return group->UnblockSource(*sourceAddr);
		case IP_ADD_SOURCE_MEMBERSHIP:
			return group->AddSSM(*sourceAddr);
		case IP_DROP_SOURCE_MEMBERSHIP:
			return group->DropSSM(*sourceAddr);
	}

	return B_ERROR;
}


static status_t
ipv4_delta_membership(ipv4_protocol* protocol, int option,
	net_interface* interface, const in_addr* groupAddr,
	const in_addr* sourceAddr)
{
	IPv4MulticastFilter& filter = protocol->multicast_filter;
	IPv4GroupInterface* state = NULL;
	status_t status = B_OK;

	switch (option) {
		case IP_ADD_MEMBERSHIP:
		case IP_ADD_SOURCE_MEMBERSHIP:
			status = filter.GetState(*groupAddr, interface, state, true);
			break;

		case IP_DROP_MEMBERSHIP:
		case IP_BLOCK_SOURCE:
		case IP_UNBLOCK_SOURCE:
		case IP_DROP_SOURCE_MEMBERSHIP:
			filter.GetState(*groupAddr, interface, state, false);
			if (state == NULL) {
				if (option == IP_DROP_MEMBERSHIP
					|| option == IP_DROP_SOURCE_MEMBERSHIP)
					return EADDRNOTAVAIL;

				return B_BAD_VALUE;
			}
			break;
	}

	if (status != B_OK)
		return status;

	status = ipv4_delta_group(state, option, interface, sourceAddr);
	filter.ReturnState(state);
	return status;
}


static int
generic_to_ipv4(int option)
{
	switch (option) {
		case MCAST_JOIN_GROUP:
			return IP_ADD_MEMBERSHIP;
		case MCAST_JOIN_SOURCE_GROUP:
			return IP_ADD_SOURCE_MEMBERSHIP;
		case MCAST_LEAVE_GROUP:
			return IP_DROP_MEMBERSHIP;
		case MCAST_BLOCK_SOURCE:
			return IP_BLOCK_SOURCE;
		case MCAST_UNBLOCK_SOURCE:
			return IP_UNBLOCK_SOURCE;
		case MCAST_LEAVE_SOURCE_GROUP:
			return IP_DROP_SOURCE_MEMBERSHIP;
	}

	return -1;
}


static net_interface*
get_multicast_interface(ipv4_protocol* protocol, const in_addr* address)
{
	// TODO: this is broken and leaks references
	sockaddr_in groupAddr;
	net_route* route = sDatalinkModule->get_route(sDomain,
		fill_sockaddr_in(&groupAddr, address ? address->s_addr : INADDR_ANY));
	if (route == NULL)
		return NULL;

	return route->interface_address->interface;
}


static status_t
ipv4_delta_membership(ipv4_protocol* protocol, int option,
	in_addr* interfaceAddr, in_addr* groupAddr, in_addr* sourceAddr)
{
	net_interface* interface = NULL;

	if (interfaceAddr->s_addr == INADDR_ANY) {
		interface = get_multicast_interface(protocol, groupAddr);
	} else {
		sockaddr_in address;
		interface = sDatalinkModule->get_interface_with_address(
			fill_sockaddr_in(&address, interfaceAddr->s_addr));
	}

	if (interface == NULL)
		return B_DEVICE_NOT_FOUND;

	return ipv4_delta_membership(protocol, option, interface,
		groupAddr, sourceAddr);
}


static status_t
ipv4_generic_delta_membership(ipv4_protocol* protocol, int option,
	uint32 index, const sockaddr_storage* _groupAddr,
	const sockaddr_storage* _sourceAddr)
{
	if (_groupAddr->ss_family != AF_INET
		|| (_sourceAddr != NULL && _sourceAddr->ss_family != AF_INET))
		return B_BAD_VALUE;

	const in_addr* groupAddr = &((const sockaddr_in*)_groupAddr)->sin_addr;

	// TODO: this is broken and leaks references
	net_interface* interface;
	if (index == 0)
		interface = get_multicast_interface(protocol, groupAddr);
	else
		interface = sDatalinkModule->get_interface(sDomain, index);

	if (interface == NULL)
		return B_DEVICE_NOT_FOUND;

	const in_addr* sourceAddr = NULL;
	if (_sourceAddr != NULL)
		sourceAddr = &((const sockaddr_in*)_sourceAddr)->sin_addr;

	return ipv4_delta_membership(protocol, generic_to_ipv4(option), interface,
		groupAddr, sourceAddr);
}


//	#pragma mark - module interface


net_protocol*
ipv4_init_protocol(net_socket* socket)
{
	ipv4_protocol* protocol = new (std::nothrow) ipv4_protocol();
	if (protocol == NULL)
		return NULL;

	protocol->raw = NULL;
	protocol->service_type = 0;
	protocol->time_to_live = kDefaultTTL;
	protocol->multicast_time_to_live = kDefaultMulticastTTL;
	protocol->multicast_loopback = kDefaultMulticastLoopback;
	protocol->flags = 0;
	protocol->multicast_address = NULL;
	return protocol;
}


status_t
ipv4_uninit_protocol(net_protocol* _protocol)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	delete protocol;

	return B_OK;
}


/*!	Since open() is only called on the top level protocol, when we get here
	it means we are on a SOCK_RAW socket.
*/
status_t
ipv4_open(net_protocol* _protocol)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	// Only root may open raw sockets
	if (geteuid() != 0)
		return B_NOT_ALLOWED;

	RawSocket* raw = new (std::nothrow) RawSocket(protocol->socket);
	if (raw == NULL)
		return B_NO_MEMORY;

	status_t status = raw->InitCheck();
	if (status != B_OK) {
		delete raw;
		return status;
	}

	TRACE_SK(protocol, "Open()");

	protocol->raw = raw;

	MutexLocker locker(sRawSocketsLock);
	sRawSockets.Add(raw);
	return B_OK;
}


status_t
ipv4_close(net_protocol* _protocol)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;
	RawSocket* raw = protocol->raw;
	if (raw == NULL)
		return B_ERROR;

	TRACE_SK(protocol, "Close()");

	MutexLocker locker(sRawSocketsLock);
	sRawSockets.Remove(raw);
	delete raw;
	protocol->raw = NULL;

	return B_OK;
}


status_t
ipv4_free(net_protocol* protocol)
{
	return B_OK;
}


status_t
ipv4_connect(net_protocol* _protocol, const struct sockaddr* address)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;
	RawSocket* raw = protocol->raw;
	if (raw == NULL)
		return B_ERROR;
	if (address->sa_len != sizeof(struct sockaddr_in))
		return B_BAD_VALUE;
	if (address->sa_family != AF_INET)
		return EAFNOSUPPORT;

	memcpy(&protocol->socket->peer, address, sizeof(struct sockaddr_in));
	sSocketModule->set_connected(protocol->socket);

	return B_OK;
}


status_t
ipv4_accept(net_protocol* protocol, struct net_socket** _acceptedSocket)
{
	return B_NOT_SUPPORTED;
}


status_t
ipv4_control(net_protocol* _protocol, int level, int option, void* value,
	size_t* _length)
{
	if ((level & LEVEL_MASK) != IPPROTO_IP)
		return sDatalinkModule->control(sDomain, option, value, _length);

	return B_BAD_VALUE;
}


status_t
ipv4_getsockopt(net_protocol* _protocol, int level, int option, void* value,
	int* _length)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	if (level == IPPROTO_IP) {
		bool isDgramOrRaw = protocol->socket->type == SOCK_DGRAM
			|| protocol->socket->type == SOCK_RAW;
		if (option == IP_HDRINCL) {
			return get_int_option(value, *_length,
				(protocol->flags & IP_FLAG_HEADER_INCLUDED) != 0);
		}
		if (option == IP_RECVDSTADDR) {
			return get_int_option(value, *_length,
				(protocol->flags & IP_FLAG_RECEIVE_DEST_ADDR) != 0);
		}
		if (option == IP_TTL)
			return get_int_option(value, *_length, protocol->time_to_live);
		if (option == IP_TOS)
			return get_int_option(value, *_length, protocol->service_type);
		if (option == IP_MULTICAST_IF) {
			if (*_length != sizeof(struct in_addr))
				return B_BAD_VALUE;
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			struct sockaddr_in defaultAddress;
			defaultAddress.sin_addr.s_addr = htonl(INADDR_ANY);
			struct sockaddr_in* address =
				(struct sockaddr_in*)protocol->multicast_address;
			if (address == NULL)
				address = &defaultAddress;
			if (user_memcpy(value, &address->sin_addr, sizeof(struct in_addr))
					!= B_OK) {
				return B_BAD_ADDRESS;
			}
			return B_OK;
		}
		if (option == IP_MULTICAST_TTL) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			return get_char_int_option(value, *_length,
				protocol->multicast_time_to_live);
		}
		if (option == IP_MULTICAST_LOOP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			return get_char_int_option(value, *_length,
				protocol->multicast_loopback ? 1 : 0);
		}
		if (option == IP_ADD_MEMBERSHIP
			|| option == IP_DROP_MEMBERSHIP
			|| option == IP_BLOCK_SOURCE
			|| option == IP_UNBLOCK_SOURCE
			|| option == IP_ADD_SOURCE_MEMBERSHIP
			|| option == IP_DROP_SOURCE_MEMBERSHIP
			|| option == MCAST_JOIN_GROUP
			|| option == MCAST_LEAVE_GROUP
			|| option == MCAST_BLOCK_SOURCE
			|| option == MCAST_UNBLOCK_SOURCE
			|| option == MCAST_JOIN_SOURCE_GROUP
			|| option == MCAST_LEAVE_SOURCE_GROUP) {
			// RFC 3678, Section 4.1:
			// ``An error of EOPNOTSUPP is returned if these options are
			// used with getsockopt().''
			return B_NOT_SUPPORTED;
		}

		dprintf("IPv4::getsockopt(): get unknown option: %d\n", option);
		return ENOPROTOOPT;
	}

	return sSocketModule->get_option(protocol->socket, level, option, value,
		_length);
}


status_t
ipv4_setsockopt(net_protocol* _protocol, int level, int option,
	const void* value, int length)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	if (level == IPPROTO_IP) {
		bool isDgramOrRaw = protocol->socket->type == SOCK_DGRAM
			|| protocol->socket->type == SOCK_RAW;
		if (option == IP_HDRINCL) {
			int headerIncluded;
			if (length != sizeof(int))
				return B_BAD_VALUE;
			if (user_memcpy(&headerIncluded, value, sizeof(headerIncluded))
					!= B_OK)
				return B_BAD_ADDRESS;

			if (headerIncluded)
				protocol->flags |= IP_FLAG_HEADER_INCLUDED;
			else
				protocol->flags &= ~IP_FLAG_HEADER_INCLUDED;

			return B_OK;
		}
		if (option == IP_RECVDSTADDR) {
			int getAddress;
			if (length != sizeof(int))
				return B_BAD_VALUE;
			if (user_memcpy(&getAddress, value, sizeof(int)) != B_OK)
				return B_BAD_ADDRESS;

			if (getAddress && (protocol->socket->type == SOCK_DGRAM
					|| protocol->socket->type == SOCK_RAW))
				protocol->flags |= IP_FLAG_RECEIVE_DEST_ADDR;
			else
				protocol->flags &= ~IP_FLAG_RECEIVE_DEST_ADDR;

			return B_OK;
		}
		if (option == IP_TTL)
			return set_int_option(protocol->time_to_live, value, length);
		if (option == IP_TOS)
			return set_int_option(protocol->service_type, value, length);
		if (option == IP_MULTICAST_IF) {
			if (length != sizeof(struct in_addr))
				return B_BAD_VALUE;
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;

			struct sockaddr_in* address = new (std::nothrow) sockaddr_in;
			if (address == NULL)
				return B_NO_MEMORY;

			struct in_addr sin_addr;
			if (user_memcpy(&sin_addr, value, sizeof(struct in_addr))
					!= B_OK) {
				delete address;
				return B_BAD_ADDRESS;
			}
			fill_sockaddr_in(address, sin_addr.s_addr);

			// Using INADDR_ANY to remove the previous setting.
			if (address->sin_addr.s_addr == htonl(INADDR_ANY)) {
				delete address;
				delete protocol->multicast_address;
				protocol->multicast_address = NULL;
				return B_OK;
			}

			struct net_interface* interface
				= sDatalinkModule->get_interface_with_address(
					(sockaddr*)address);
			if (interface == NULL) {
				delete address;
				return EADDRNOTAVAIL;
			}

			delete protocol->multicast_address;
			protocol->multicast_address = (struct sockaddr*)address;

			sDatalinkModule->put_interface(interface);
			return B_OK;
		}
		if (option == IP_MULTICAST_TTL) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			return set_char_int_option(protocol->multicast_time_to_live, value,
				length);
		}
		if (option == IP_MULTICAST_LOOP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			uint8 multicast_loopback;
			status_t status = set_char_int_option(multicast_loopback, value,
				length);
			if (status == B_OK)
				protocol->multicast_loopback = multicast_loopback != 0;
			return status;
		}
		if (option == IP_ADD_MEMBERSHIP || option == IP_DROP_MEMBERSHIP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			ip_mreq mreq;
			if (length != sizeof(ip_mreq))
				return B_BAD_VALUE;
			if (user_memcpy(&mreq, value, sizeof(ip_mreq)) != B_OK)
				return B_BAD_ADDRESS;

			return ipv4_delta_membership(protocol, option, &mreq.imr_interface,
				&mreq.imr_multiaddr, NULL);
		}
		if (option == IP_BLOCK_SOURCE
			|| option == IP_UNBLOCK_SOURCE
			|| option == IP_ADD_SOURCE_MEMBERSHIP
			|| option == IP_DROP_SOURCE_MEMBERSHIP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			ip_mreq_source mreq;
			if (length != sizeof(ip_mreq_source))
				return B_BAD_VALUE;
			if (user_memcpy(&mreq, value, sizeof(ip_mreq_source)) != B_OK)
				return B_BAD_ADDRESS;

			return ipv4_delta_membership(protocol, option, &mreq.imr_interface,
				&mreq.imr_multiaddr, &mreq.imr_sourceaddr);
		}
		if (option == MCAST_LEAVE_GROUP || option == MCAST_JOIN_GROUP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			group_req greq;
			if (length != sizeof(group_req))
				return B_BAD_VALUE;
			if (user_memcpy(&greq, value, sizeof(group_req)) != B_OK)
				return B_BAD_ADDRESS;

			return ipv4_generic_delta_membership(protocol, option,
				greq.gr_interface, &greq.gr_group, NULL);
		}
		if (option == MCAST_BLOCK_SOURCE
			|| option == MCAST_UNBLOCK_SOURCE
			|| option == MCAST_JOIN_SOURCE_GROUP
			|| option == MCAST_LEAVE_SOURCE_GROUP) {
			if (!isDgramOrRaw)
				return B_NOT_SUPPORTED;
			group_source_req greq;
			if (length != sizeof(group_source_req))
				return B_BAD_VALUE;
			if (user_memcpy(&greq, value, sizeof(group_source_req)) != B_OK)
				return B_BAD_ADDRESS;

			return ipv4_generic_delta_membership(protocol, option,
				greq.gsr_interface, &greq.gsr_group, &greq.gsr_source);
		}

		dprintf("IPv4::setsockopt(): set unknown option: %d\n", option);
		return ENOPROTOOPT;
	}

	return sSocketModule->set_option(protocol->socket, level, option,
		value, length);
}


status_t
ipv4_bind(net_protocol* protocol, const struct sockaddr* address)
{
	if (address->sa_family != AF_INET)
		return EAFNOSUPPORT;

	// only INADDR_ANY and addresses of local interfaces are accepted:
	if (((sockaddr_in*)address)->sin_addr.s_addr == INADDR_ANY
		|| IN_MULTICAST(ntohl(((sockaddr_in*)address)->sin_addr.s_addr))
		|| sDatalinkModule->is_local_address(sDomain, address, NULL, NULL)) {
		memcpy(&protocol->socket->address, address, sizeof(struct sockaddr_in));
		protocol->socket->address.ss_len = sizeof(struct sockaddr_in);
			// explicitly set length, as our callers can't be trusted to
			// always provide the correct length!
		return B_OK;
	}

	return EADDRNOTAVAIL;
		// address is unknown on this host
}


status_t
ipv4_unbind(net_protocol* protocol, struct sockaddr* address)
{
	// nothing to do here
	return B_OK;
}


status_t
ipv4_listen(net_protocol* protocol, int count)
{
	return B_NOT_SUPPORTED;
}


status_t
ipv4_shutdown(net_protocol* protocol, int direction)
{
	return B_NOT_SUPPORTED;
}


status_t
ipv4_send_routed_data(net_protocol* _protocol, struct net_route* route,
	net_buffer* buffer)
{
	if (route == NULL)
		return B_BAD_VALUE;

	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;
	net_interface_address* interfaceAddress = route->interface_address;
	net_interface* interface = interfaceAddress->interface;

	TRACE_SK(protocol, "SendRoutedData(%p, %p [%" B_PRIu32 " bytes])", route,
		buffer, buffer->size);

	sockaddr_in& source = *(sockaddr_in*)buffer->source;
	sockaddr_in& destination = *(sockaddr_in*)buffer->destination;
	sockaddr_in* broadcastAddress = (sockaddr_in*)interfaceAddress->destination;

	bool checksumNeeded = true;
	bool headerIncluded = false;
	if (protocol != NULL)
		headerIncluded = (protocol->flags & IP_FLAG_HEADER_INCLUDED) != 0;

	buffer->msg_flags &= ~(MSG_BCAST | MSG_MCAST);

	if (destination.sin_addr.s_addr == INADDR_ANY)
		return EDESTADDRREQ;

	if ((interface->device->flags & IFF_BROADCAST) != 0
		&& (destination.sin_addr.s_addr == INADDR_BROADCAST
			|| (broadcastAddress != NULL && destination.sin_addr.s_addr
					== broadcastAddress->sin_addr.s_addr))) {
		if (protocol && !(protocol->socket->options & SO_BROADCAST))
			return B_BAD_VALUE;
		buffer->msg_flags |= MSG_BCAST;
	} else if (IN_MULTICAST(ntohl(destination.sin_addr.s_addr)))
		buffer->msg_flags |= MSG_MCAST;

	// Add IP header (if needed)

	if (!headerIncluded) {
		NetBufferPrepend<ipv4_header> header(buffer);
		if (header.Status() != B_OK)
			return header.Status();

		header->version = IPV4_VERSION;
		header->header_length = sizeof(ipv4_header) / 4;
		header->service_type = protocol ? protocol->service_type : 0;
		header->total_length = htons(buffer->size);
		header->id = htons(atomic_add(&sPacketID, 1));
		header->fragment_offset = 0;
		if (protocol) {
			header->time_to_live = (buffer->msg_flags & MSG_MCAST) != 0
				? protocol->multicast_time_to_live : protocol->time_to_live;
		} else {
			header->time_to_live = (buffer->msg_flags & MSG_MCAST) != 0
				? kDefaultMulticastTTL : kDefaultTTL;
		}
		header->protocol = protocol
			? protocol->socket->protocol : buffer->protocol;
		header->checksum = 0;

		header->source = source.sin_addr.s_addr;
		header->destination = destination.sin_addr.s_addr;

		TRACE_ONLY(dump_ipv4_header(*header));
	} else {
		// if IP_HDRINCL, check if the source address is set
		NetBufferHeaderReader<ipv4_header> header(buffer);
		if (header.Status() != B_OK)
			return header.Status();

		if (header->source == 0) {
			header->source = source.sin_addr.s_addr;
			header->checksum = 0;
			header.Sync();
		} else
			checksumNeeded = false;

		TRACE("  Header was already supplied:");
		TRACE_ONLY(dump_ipv4_header(*header));
	}

	if (buffer->size > 0xffff)
		return EMSGSIZE;

	if (checksumNeeded) {
		*IPChecksumField(buffer) = gBufferModule->checksum(buffer, 0,
			sizeof(ipv4_header), true);
		buffer->buffer_flags |= NET_BUFFER_L3_CHECKSUM_VALID;
	}

	if ((buffer->msg_flags & MSG_MCAST) != 0
		&& (protocol != NULL && protocol->multicast_loopback)) {
		// copy an IP multicast packet to the input queue of the loopback
		// interface
		net_buffer *loopbackBuffer = gBufferModule->duplicate(buffer);
		if (loopbackBuffer == NULL)
			return B_NO_MEMORY;
		status_t status = B_ERROR;

		// get the IPv4 loopback address
		struct sockaddr loopbackAddress;
		gIPv4AddressModule.get_loopback_address(&loopbackAddress);

		// get the matching interface address if any
		net_interface_address* address =
			sDatalinkModule->get_interface_address(&loopbackAddress);
		if (address == NULL || (address->interface->flags & IFF_UP) == 0) {
			sDatalinkModule->put_interface_address(address);
		} else {
			sDatalinkModule->put_interface_address(
				loopbackBuffer->interface_address);
			loopbackBuffer->interface_address = address;
			status = ipv4_receive_data(loopbackBuffer);
		}

		if (status != B_OK)
			gBufferModule->free(loopbackBuffer);
	}

	TRACE_SK(protocol, "  SendRoutedData(): header chksum: %" B_PRIu32
		", buffer checksum: %" B_PRIu32,
		gBufferModule->checksum(buffer, 0, sizeof(ipv4_header), true),
		gBufferModule->checksum(buffer, 0, buffer->size, true));

	TRACE_SK(protocol, "  SendRoutedData(): destination: %08x",
		ntohl(destination.sin_addr.s_addr));

	uint32 mtu = route->mtu ? route->mtu : interface->device->mtu;
	if (buffer->size > mtu) {
		// we need to fragment the packet
		return send_fragments(protocol, route, buffer, mtu);
	}

	return sDatalinkModule->send_routed_data(route, buffer);
}


status_t
ipv4_send_data(net_protocol* _protocol, net_buffer* buffer)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	TRACE_SK(protocol, "SendData(%p [%" B_PRIu32 " bytes])", buffer,
		buffer->size);

	if (protocol != NULL && (protocol->flags & IP_FLAG_HEADER_INCLUDED)) {
		if (buffer->size < sizeof(ipv4_header))
			return B_BAD_VALUE;

		sockaddr_in* source = (sockaddr_in*)buffer->source;
		sockaddr_in* destination = (sockaddr_in*)buffer->destination;

		fill_sockaddr_in(source, *NetBufferField<in_addr_t,
			offsetof(ipv4_header, source)>(buffer));
		fill_sockaddr_in(destination, *NetBufferField<in_addr_t,
			offsetof(ipv4_header, destination)>(buffer));
	}

	// handle IP_MULTICAST_IF
	if (IN_MULTICAST(ntohl(
			((sockaddr_in*)buffer->destination)->sin_addr.s_addr))
		&& protocol != NULL && protocol->multicast_address != NULL) {
		net_interface_address* address = sDatalinkModule->get_interface_address(
			protocol->multicast_address);
		if (address == NULL || (address->interface->flags & IFF_UP) == 0) {
			sDatalinkModule->put_interface_address(address);
			return EADDRNOTAVAIL;
		}

		sDatalinkModule->put_interface_address(buffer->interface_address);
		buffer->interface_address = address;
			// the buffer takes over ownership of the address

		net_route* route = sDatalinkModule->get_route(sDomain, address->local);
		if (route == NULL)
			return ENETUNREACH;

		return sDatalinkModule->send_routed_data(route, buffer);
	}

	return sDatalinkModule->send_data(protocol, sDomain, buffer);
}


ssize_t
ipv4_send_avail(net_protocol* protocol)
{
	return B_ERROR;
}


status_t
ipv4_read_data(net_protocol* _protocol, size_t numBytes, uint32 flags,
	net_buffer** _buffer)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;
	RawSocket* raw = protocol->raw;
	if (raw == NULL)
		return B_ERROR;

	TRACE_SK(protocol, "ReadData(%lu, 0x%" B_PRIx32 ")", numBytes, flags);

	return raw->Dequeue(flags, _buffer);
}


ssize_t
ipv4_read_avail(net_protocol* _protocol)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;
	RawSocket* raw = protocol->raw;
	if (raw == NULL)
		return B_ERROR;

	return raw->AvailableData();
}


struct net_domain*
ipv4_get_domain(net_protocol* protocol)
{
	return sDomain;
}


size_t
ipv4_get_mtu(net_protocol* protocol, const struct sockaddr* address)
{
	net_route* route = sDatalinkModule->get_route(sDomain, address);
	if (route == NULL)
		return 0;

	size_t mtu;
	if (route->mtu != 0)
		mtu = route->mtu;
	else
		mtu = route->interface_address->interface->device->mtu;

	sDatalinkModule->put_route(sDomain, route);

	if (mtu > 0xffff)
		mtu = 0xffff;

	return mtu - sizeof(ipv4_header);
}


status_t
ipv4_receive_data(net_buffer* buffer)
{
	TRACE("ipv4_receive_data(%p [%" B_PRIu32 " bytes])", buffer, buffer->size);

	uint16 headerLength = 0;
	{
	NetBufferHeaderReader<ipv4_header> bufferHeader(buffer);
	if (bufferHeader.Status() != B_OK)
		return bufferHeader.Status();

	ipv4_header& header = bufferHeader.Data();
	TRACE_ONLY(dump_ipv4_header(header));

	if (header.version != IPV4_VERSION)
		return B_BAD_TYPE;

	uint16 packetLength = header.TotalLength();
	headerLength = header.HeaderLength();
	if (packetLength > buffer->size
		|| headerLength < sizeof(ipv4_header))
		return B_BAD_DATA;

	if ((buffer->buffer_flags & NET_BUFFER_L3_CHECKSUM_VALID) == 0) {
		if (gBufferModule->checksum(buffer, 0, headerLength, true) != 0)
			return B_BAD_DATA;
	}

	// lower layers notion of broadcast or multicast have no relevance to us
	// other than deciding whether to send an ICMP error
	bool wasMulticast = (buffer->msg_flags & (MSG_BCAST | MSG_MCAST)) != 0;
	bool notForUs = false;
	buffer->msg_flags &= ~(MSG_BCAST | MSG_MCAST);

	sockaddr_in destination;
	fill_sockaddr_in(&destination, header.destination);

	if (header.destination == INADDR_BROADCAST) {
		buffer->msg_flags |= MSG_BCAST;

		// Find first interface with a matching family
		if (!sDatalinkModule->is_local_link_address(sDomain, true,
				buffer->destination, &buffer->interface_address))
			notForUs = !wasMulticast;
	} else if (IN_MULTICAST(ntohl(header.destination))) {
		buffer->msg_flags |= MSG_MCAST;
	} else {
		uint32 matchedAddressType = 0;

		// test if the packet is really for us
		if (!sDatalinkModule->is_local_address(sDomain, (sockaddr*)&destination,
				&buffer->interface_address, &matchedAddressType)
			&& !sDatalinkModule->is_local_link_address(sDomain, true,
				buffer->destination, &buffer->interface_address)) {
			// if the buffer was a link layer multicast, regard it as a
			// broadcast, and let the upper levels decide what to do with it
			if (wasMulticast)
				buffer->msg_flags |= MSG_BCAST;
			else
				notForUs = true;
		} else {
			// copy over special address types (MSG_BCAST or MSG_MCAST):
			buffer->msg_flags |= matchedAddressType;
		}
	}

	// set net_buffer's source/destination address
	fill_sockaddr_in((struct sockaddr_in*)buffer->source, header.source);
	memcpy(buffer->destination, &destination, sizeof(sockaddr_in));

	buffer->protocol = header.protocol;

	if (notForUs) {
		TRACE("  ipv4_receive_data(): packet was not for us %x -> %x",
			ntohl(header.source), ntohl(header.destination));

		if (!wasMulticast) {
			// Send ICMP error: Host unreachable
			sDomain->module->error_reply(NULL, buffer, B_NET_ERROR_UNREACH_HOST,
				NULL);
		}

		return B_ERROR;
	}

	// remove any trailing/padding data
	status_t status = gBufferModule->trim(buffer, packetLength);
	if (status != B_OK)
		return status;

	// check for fragmentation
	uint16 fragmentOffset = header.FragmentOffset();
	if ((fragmentOffset & IP_MORE_FRAGMENTS) != 0
		|| (fragmentOffset & IP_FRAGMENT_OFFSET_MASK) != 0) {
		// this is a fragment
		TRACE("  ipv4_receive_data(): Found a Fragment!");
		status = reassemble_fragments(header, &buffer);
		TRACE("  ipv4_receive_data():  -> %s", strerror(status));
		if (status != B_OK)
			return status;

		if (buffer == NULL) {
			// buffer was put into fragment packet
			TRACE("  ipv4_receive_data(): Not yet assembled.");
			return B_OK;
		}
	}

	// Since the buffer might have been changed (reassembled fragment)
	// we must no longer access bufferHeader or header anymore after
	// this point
	}

	bool rawDelivered = raw_receive_data(buffer);

	// Preserve the ipv4 header for ICMP processing
	gBufferModule->store_header(buffer);
	gBufferModule->remove_header(buffer, headerLength);
		// the header is of variable size and may include IP options
		// (TODO: that we ignore for now)

	net_protocol_module_info* module = receiving_protocol(buffer->protocol);
	if (module == NULL) {
		// no handler for this packet
		if (!rawDelivered) {
			sDomain->module->error_reply(NULL, buffer,
				B_NET_ERROR_UNREACH_PROTOCOL, NULL);
		}
		return EAFNOSUPPORT;
	}

	if ((buffer->msg_flags & MSG_MCAST) != 0) {
		// Unfortunately historical reasons dictate that the IP multicast
		// model be a little different from the unicast one. We deliver
		// this frame directly to all sockets registered with interest
		// for this multicast group.
		deliver_multicast(module, buffer, false);
		gBufferModule->free(buffer);
		return B_OK;
	}

	return module->receive_data(buffer);
}


status_t
ipv4_deliver_data(net_protocol* _protocol, net_buffer* buffer)
{
	ipv4_protocol* protocol = (ipv4_protocol*)_protocol;

	if (protocol->raw == NULL)
		return B_ERROR;

	return protocol->raw->EnqueueClone(buffer);
}


status_t
ipv4_error_received(net_error error, net_buffer* buffer)
{
	TRACE("  ipv4_error_received(error %d, buffer %p [%" B_PRIu32 " bytes])",
		(int)error, buffer, buffer->size);

	NetBufferHeaderReader<ipv4_header> bufferHeader(buffer);
	if (bufferHeader.Status() != B_OK)
		return bufferHeader.Status();

	ipv4_header& header = bufferHeader.Data();
	TRACE_ONLY(dump_ipv4_header(header));

	// We do not check the packet length, as we usually only get a part of it
	uint16 headerLength = header.HeaderLength();
	if (header.version != IPV4_VERSION
		|| headerLength < sizeof(ipv4_header)
		|| gBufferModule->checksum(buffer, 0, headerLength, true) != 0)
		return B_BAD_DATA;

	// Restore addresses of the original buffer

	// lower layers notion of broadcast or multicast have no relevance to us
	// TODO: they actually have when deciding whether to send an ICMP error
	buffer->msg_flags &= ~(MSG_BCAST | MSG_MCAST);

	fill_sockaddr_in((struct sockaddr_in*)buffer->source, header.source);
	fill_sockaddr_in((struct sockaddr_in*)buffer->destination,
		header.destination);

	if (header.destination == INADDR_BROADCAST)
		buffer->msg_flags |= MSG_BCAST;
	else if (IN_MULTICAST(ntohl(header.destination)))
		buffer->msg_flags |= MSG_MCAST;

	// test if the packet is really from us
	if (!sDatalinkModule->is_local_address(sDomain, buffer->source, NULL,
			NULL)) {
		TRACE("  ipv4_error_received(): packet was not for us %x -> %x",
			ntohl(header.source), ntohl(header.destination));
		return B_ERROR;
	}

	buffer->protocol = header.protocol;

	bufferHeader.Remove(headerLength);

	net_protocol_module_info* protocol = receiving_protocol(buffer->protocol);
	if (protocol == NULL)
		return B_ERROR;

	// propagate error
	return protocol->error_received(error, buffer);
}


status_t
ipv4_error_reply(net_protocol* protocol, net_buffer* cause, net_error error,
	net_error_data* errorData)
{
	// Directly obtain the ICMP protocol module
	net_protocol_module_info* icmp = receiving_protocol(IPPROTO_ICMP);
	if (icmp == NULL)
		return B_ERROR;

	return icmp->error_reply(protocol, cause, error, errorData);
}


ssize_t
ipv4_process_ancillary_data_no_container(net_protocol* protocol,
	net_buffer* buffer, void* msgControl, size_t msgControlLen)
{
	ssize_t bytesWritten = 0;

	if ((((ipv4_protocol*)protocol)->flags & IP_FLAG_RECEIVE_DEST_ADDR) != 0) {
		if (msgControlLen < CMSG_SPACE(sizeof(struct in_addr)))
			return B_NO_MEMORY;

		cmsghdr* messageHeader = (cmsghdr*)msgControl;
		messageHeader->cmsg_len = CMSG_LEN(sizeof(struct in_addr));
		messageHeader->cmsg_level = IPPROTO_IP;
		messageHeader->cmsg_type = IP_RECVDSTADDR;

		memcpy(CMSG_DATA(messageHeader),
		 	&((struct sockaddr_in*)buffer->destination)->sin_addr,
		 	sizeof(struct in_addr));

		bytesWritten += CMSG_SPACE(sizeof(struct in_addr));
	}

	return bytesWritten;
}


//	#pragma mark -


status_t
init_ipv4()
{
	sPacketID = (int32)system_time();

	mutex_init(&sRawSocketsLock, "raw sockets");
	mutex_init(&sFragmentLock, "IPv4 Fragments");
	mutex_init(&sMulticastGroupsLock, "IPv4 multicast groups");
	mutex_init(&sReceivingProtocolLock, "IPv4 receiving protocols");

	status_t status;

	sMulticastState = new MulticastState();
	if (sMulticastState == NULL) {
		status = B_NO_MEMORY;
		goto err4;
	}

	status = sMulticastState->Init();
	if (status != B_OK)
		goto err5;

	new (&sFragmentHash) FragmentTable();
	status = sFragmentHash.Init(256);
	if (status != B_OK)
		goto err5;

	new (&sRawSockets) RawSocketList;
		// static initializers do not work in the kernel,
		// so we have to do it here, manually
		// TODO: for modules, this shouldn't be required

	status = gStackModule->register_domain_protocols(AF_INET, SOCK_RAW, 0,
		"network/protocols/ipv4/v1", NULL);
	if (status != B_OK)
		goto err6;

	status = gStackModule->register_domain(AF_INET, "internet", &gIPv4Module,
		&gIPv4AddressModule, &sDomain);
	if (status != B_OK)
		goto err6;

	add_debugger_command("ipv4_multicast", dump_ipv4_multicast,
		"list all current IPv4 multicast states");

	return B_OK;

err6:
	sFragmentHash.~FragmentTable();
err5:
	delete sMulticastState;
err4:
	mutex_destroy(&sReceivingProtocolLock);
	mutex_destroy(&sMulticastGroupsLock);
	mutex_destroy(&sFragmentLock);
	mutex_destroy(&sRawSocketsLock);
	return status;
}


status_t
uninit_ipv4()
{
	mutex_lock(&sReceivingProtocolLock);

	remove_debugger_command("ipv4_multicast", dump_ipv4_multicast);

	// put all the domain receiving protocols we gathered so far
	for (uint32 i = 0; i < 256; i++) {
		if (sReceivingProtocol[i] != NULL)
			gStackModule->put_domain_receiving_protocol(sDomain, i);
	}

	gStackModule->unregister_domain(sDomain);
	mutex_unlock(&sReceivingProtocolLock);

	delete sMulticastState;
	sFragmentHash.~FragmentTable();

	mutex_destroy(&sMulticastGroupsLock);
	mutex_destroy(&sFragmentLock);
	mutex_destroy(&sRawSocketsLock);
	mutex_destroy(&sReceivingProtocolLock);

	return B_OK;
}


static status_t
ipv4_std_ops(int32 op, ...)
{
	switch (op) {
		case B_MODULE_INIT:
			return init_ipv4();
		case B_MODULE_UNINIT:
			return uninit_ipv4();

		default:
			return B_ERROR;
	}
}


net_protocol_module_info gIPv4Module = {
	{
		"network/protocols/ipv4/v1",
		0,
		ipv4_std_ops
	},
	NET_PROTOCOL_ATOMIC_MESSAGES,

	ipv4_init_protocol,
	ipv4_uninit_protocol,
	ipv4_open,
	ipv4_close,
	ipv4_free,
	ipv4_connect,
	ipv4_accept,
	ipv4_control,
	ipv4_getsockopt,
	ipv4_setsockopt,
	ipv4_bind,
	ipv4_unbind,
	ipv4_listen,
	ipv4_shutdown,
	ipv4_send_data,
	ipv4_send_routed_data,
	ipv4_send_avail,
	ipv4_read_data,
	ipv4_read_avail,
	ipv4_get_domain,
	ipv4_get_mtu,
	ipv4_receive_data,
	ipv4_deliver_data,
	ipv4_error_received,
	ipv4_error_reply,
	NULL,		// add_ancillary_data()
	NULL,		// process_ancillary_data()
	ipv4_process_ancillary_data_no_container,
	NULL,		// send_data_no_buffer()
	NULL		// read_data_no_buffer()
};

module_dependency module_dependencies[] = {
	{NET_STACK_MODULE_NAME, (module_info**)&gStackModule},
	{NET_BUFFER_MODULE_NAME, (module_info**)&gBufferModule},
	{NET_DATALINK_MODULE_NAME, (module_info**)&sDatalinkModule},
	{NET_SOCKET_MODULE_NAME, (module_info**)&sSocketModule},
	{}
};

module_info* modules[] = {
	(module_info*)&gIPv4Module,
	NULL
};
